# http://proceedings.mlr.press/v101/huang19a/huang19a.pdf
# https://www.researchgate.net/publication/220875351_Generative_Models_for_Labeling_Multi-object_Configurations_in_Images
# https://www.tensorflow.org/datasets/catalog/open_images_v4
# Auto-Encoding Progressive Generative Adversarial Networks For 3D Multi Object Scenes
TODO: datasets to experiment with
%config Completer.use_jedi = False
from ipywidgets import IntProgress
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, losses
from sklearn.mixture import GaussianMixture
from tqdm.notebook import tqdm  # tqdm_notebook is deprecated in favor of tqdm.notebook
import logging
import os
seed = 1
np.random.seed(seed)
tf.random.set_seed(seed)
batch_size = 32
epochs = 10
dataset_name = 'wider_face'
if dataset_name == 'bdd100k':
    # bdd100k is loaded from disk; image_dataset_from_directory yields batched (image, label) pairs
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        directory='../data/bdd100k/images/10k/train1/', batch_size=batch_size)  # train
    test_ds = tf.keras.preprocessing.image_dataset_from_directory(
        directory='../data/bdd100k/images/10k/test1/', batch_size=batch_size)  # test
    validation_ds = tf.keras.preprocessing.image_dataset_from_directory(
        directory='../data/bdd100k/images/10k/val1/', batch_size=batch_size)  # validation
elif dataset_name in ['flic', 'fashion_mnist', 'mnist', 'kitti']:
    # these TFDS datasets ship without a validation split, so reuse the test split
    train_ds, test_ds = tfds.load(name=dataset_name, split=['train', 'test'],
                                  as_supervised=False, download=True)
    validation_ds = test_ds
elif dataset_name in ['wider_face']:
    train_ds, test_ds, validation_ds = tfds.load(name=dataset_name,
                                                 split=['train', 'test', 'validation'],
                                                 as_supervised=False, download=True)
else:
    raise ValueError(f'Unhandled dataset {dataset_name}')
if dataset_name == 'bdd100k':
    dims = [x[0].get_shape().as_list() for x in train_ds]
    dims_df = pd.DataFrame.from_records(data=dims, columns=['batch', 'height', 'width', 'depth'])
else:
    dims = [x['image'].get_shape().as_list() for x in train_ds]
    dims_df = pd.DataFrame.from_records(data=dims, columns=['height', 'width', 'depth'])
dims_df.describe()
|       | height       | width   | depth   |
|-------|--------------|---------|---------|
| count | 12880.000000 | 12880.0 | 12880.0 |
| mean  | 888.309627   | 1024.0  | 3.0     |
| std   | 350.513446   | 0.0     | 0.0     |
| min   | 171.000000   | 1024.0  | 3.0     |
| 25%   | 682.000000   | 1024.0  | 3.0     |
| 50%   | 760.000000   | 1024.0  | 3.0     |
| 75%   | 1024.000000  | 1024.0  | 3.0     |
| max   | 9108.000000  | 1024.0  | 3.0     |
# round the smallest observed dimension down to a multiple of m so the
# strided conv/deconv stack divides the spatial dims cleanly
m = 20
height = int(min(dims_df['height']) / m) * m
width = int(min(dims_df['width']) / m) * m
# height = 2**(int(np.log2(min(dims_df['height']))))
# width = 2**(int(np.log2(min(dims_df['width']))))
depth = min(dims_df['depth'])
height = width = min(height, width)  # keep the target shape square
height, width, depth
(160, 160, 3)
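The decoder defined below reshapes its dense output to (height/4, width/4, 32) and then upsamples by a factor of four, so both spatial dimensions must be divisible by 4; rounding down to a multiple of m = 20 guarantees this. A minimal check (an added sketch, not in the original run):

# sanity-check sketch (added): the decoder upsamples 4x, so dims must divide by 4
assert height % 4 == 0 and width % 4 == 0, (height, width)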
for t in train_ds.take(3):
    print(t.keys())
dict_keys(['faces', 'image', 'image/filename'])
dict_keys(['faces', 'image', 'image/filename'])
dict_keys(['faces', 'image', 'image/filename'])
if dataset_name == 'bdd100k':
    # directory loader output is already batched; just scale pixels to [0, 1]
    train_ds = train_ds.map(lambda x0, x1: x0 / 255.)
    test_ds = test_ds.map(lambda x0, x1: x0 / 255.)
    validation_ds = validation_ds.map(lambda x0, x1: x0 / 255.)
else:
    # scale pixels to [0, 1] and resize every image to a common shape
    train_ds = train_ds.map(lambda x: tf.image.resize(
        images=tf.cast(x['image'], dtype=tf.float32) / 255., size=[height, width]))
    train_ds = train_ds.batch(batch_size, drop_remainder=True)
    ###
    test_ds = test_ds.map(lambda x: tf.image.resize(
        tf.cast(x['image'], dtype=tf.float32) / 255., size=[height, width]))
    test_ds = test_ds.batch(batch_size, drop_remainder=True)
    ###
    validation_ds = validation_ds.map(lambda x: tf.image.resize(
        tf.cast(x['image'], dtype=tf.float32) / 255., size=[height, width]))
    validation_ds = validation_ds.batch(batch_size, drop_remainder=True)
###
# zip each dataset with itself so Model.fit sees (input, target) pairs
train_ds_double_zipped = tf.data.Dataset.zip(datasets=(train_ds, train_ds))
test_ds_double_zipped = tf.data.Dataset.zip(datasets=(test_ds, test_ds))
validation_ds_double_zipped = tf.data.Dataset.zip(datasets=(validation_ds, validation_ds))
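Zipping a dataset with itself is one way to produce the (input, target) pairs that Model.fit expects for an autoencoder. An equivalent formulation, noted here only as a sketch (the original uses Dataset.zip), maps each batch to a pair:

# equivalent sketch (added): map each batch to an (input, target) pair
# train_ds_double_zipped = train_ds.map(lambda x: (x, x))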
latent_dim = 1024  # matches the recorded run below (model path and summary show z_dim 1024)
class CAE(tf.keras.Model):
    """Convolutional autoencoder."""

    def __init__(self, latent_dim):
        super(CAE, self).__init__()
        self.latent_dim = latent_dim
        self.logger = logging.getLogger('CAE')
        self.encoder = tf.keras.Sequential(name='encoder', layers=[
            tf.keras.layers.InputLayer(input_shape=(height, width, depth)),
            tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            # No activation
            tf.keras.layers.Dense(latent_dim),
        ])
        self.decoder = tf.keras.Sequential(name='decoder', layers=[
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=int(height / 4) * int(width / 4) * 32, activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=(int(height / 4), int(width / 4), 32)),
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same', activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same', activation='relu'),
            # No activation
            tf.keras.layers.Conv2DTranspose(
                filters=depth, kernel_size=3, strides=1, padding='same'),
        ])
        self.encoder.summary()
        self.decoder.summary()

    def call(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
cae = CAE(latent_dim)
cae.compile(optimizer='adam', loss=losses.MeanSquaredError())
Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 79, 79, 32) 896 _________________________________________________________________ conv2d_1 (Conv2D) (None, 39, 39, 64) 18496 _________________________________________________________________ flatten (Flatten) (None, 97344) 0 _________________________________________________________________ dense (Dense) (None, 1024) 99681280 ================================================================= Total params: 99,700,672 Trainable params: 99,700,672 Non-trainable params: 0 _________________________________________________________________ Model: "decoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 51200) 52480000 _________________________________________________________________ reshape (Reshape) (None, 40, 40, 32) 0 _________________________________________________________________ conv2d_transpose (Conv2DTran (None, 80, 80, 64) 18496 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 160, 160, 32) 18464 _________________________________________________________________ conv2d_transpose_2 (Conv2DTr (None, 160, 160, 3) 867 ================================================================= Total params: 52,517,827 Trainable params: 52,517,827 Non-trainable params: 0 _________________________________________________________________
model_file_path = f'./models/cae_dataset_{dataset_name}_z_dim_{latent_dim}_data_dim_{height}x{width}x{depth}'
print(f'model path = {model_file_path}')
model path = ./models/cae_dataset_wider_face_z_dim_1024_data_dim_160x160x3
if os.path.exists(model_file_path):
    print('loading saved model')
    cae = tf.keras.models.load_model(filepath=model_file_path)
else:
    print('building model')
    # use checkpoints to save model fitting progress
    # https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
    checkpoint_filepath = './checkpoint'
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        mode='min',  # val_loss must be minimized; 'max' would checkpoint the worst epoch
        save_best_only=True)
    # Model weights are saved at the end of every epoch, if they are the best seen
    # so far.
    cae.fit(x=train_ds_double_zipped, validation_data=test_ds_double_zipped, epochs=epochs,
            callbacks=[model_checkpoint_callback], use_multiprocessing=True)
    # The best model weights are loaded back into the model.
    cae.load_weights(checkpoint_filepath)
    print('saving model')
    cae.save(filepath=model_file_path)
building model
Epoch 1/10
402/402 [==============================] - 524s 1s/step - loss: 0.0799 - val_loss: 0.0392
Epoch 2/10
402/402 [==============================] - 2946s 7s/step - loss: 0.0379 - val_loss: 0.0314
Epoch 3/10
402/402 [==============================] - 1353s 3s/step - loss: 0.0293 - val_loss: 0.0272
Epoch 4/10
402/402 [==============================] - 4426s 11s/step - loss: 0.0255 - val_loss: 0.0238
Epoch 5/10
402/402 [==============================] - 2632s 7s/step - loss: 0.0231 - val_loss: 0.0223
Epoch 6/10
402/402 [==============================] - 2468s 6s/step - loss: 0.0214 - val_loss: 0.0211
Epoch 7/10
402/402 [==============================] - 592s 1s/step - loss: 0.0199 - val_loss: 0.0203
Epoch 8/10
402/402 [==============================] - 575s 1s/step - loss: 0.0188 - val_loss: 0.0194
Epoch 9/10
402/402 [==============================] - 546s 1s/step - loss: 0.0180 - val_loss: 0.0186
Epoch 10/10
402/402 [==============================] - 1522s 4s/step - loss: 0.0170 - val_loss: 0.0187
saving model
INFO:tensorflow:Assets written to: ./models/cae_dataset_wider_face_z_dim_1024_data_dim_160x160x3/assets
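Note that val_loss ticks up slightly at epoch 10 (0.0186 to 0.0187), which is exactly why the best checkpoint is restored before saving. An EarlyStopping callback could be added alongside the checkpoint; a sketch under the assumption that the same monitor is wanted:

# optional sketch (added, not used in the original run): stop when val_loss plateaus
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', patience=2, restore_best_weights=True)
# cae.fit(..., callbacks=[model_checkpoint_callback, early_stop])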
# create the validation dataset tensor by folding all batches into one tensor
for e in validation_ds.take(1):
    initial_state = tf.zeros(dtype=tf.float32, shape=e.shape)
validation_ds_tensor = validation_ds.reduce(
    initial_state=initial_state, reduce_func=lambda x, y: tf.concat(values=[x, y], axis=0))
validation_ds_tensor = validation_ds_tensor[batch_size:]  # drop the dummy initial state
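The zero tensor above exists only because Dataset.reduce requires an initial state, which is why the first batch_size rows are dropped afterwards. An equivalent and arguably simpler formulation (a sketch, not what the original run used):

# equivalent sketch (added): concatenate the eagerly iterated batches directly,
# with no dummy initial state to strip
validation_ds_tensor_alt = tf.concat(list(validation_ds), axis=0)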
# calculate the validation loss; it is comparable across datasets because
# all pixels are scaled to [0, 1]
y_predicted = cae.predict(validation_ds)
cae_loss = cae.loss(y_pred=y_predicted, y_true=validation_ds_tensor).numpy()
print(f'CAE loss for dataset {dataset_name} = {np.round(cae_loss, 4)}')
CAE loss for dataset wider_face = 0.03909999877214432
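Beyond the scalar mean, the distribution of per-image reconstruction errors can flag hard examples. A short sketch (my addition), assuming the y_predicted and validation_ds_tensor computed above:

# sketch (added): histogram of per-image reconstruction MSE
per_image_mse = np.mean((y_predicted - validation_ds_tensor.numpy()) ** 2, axis=(1, 2, 3))
plt.hist(per_image_mse, bins=50)
plt.xlabel('per-image MSE')
plt.show()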
# plot original vs. decoded images for a couple of validation batches
for batch in validation_ds.take(2):
    z = cae.encoder(batch).numpy()
    decoded_imgs = cae.decoder(z).numpy()
    for i in range(batch.shape[0]):
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(batch[i])
        ax2.imshow(np.clip(decoded_imgs[i], 0., 1.))  # decoder output is unbounded; clip for imshow
        plt.show()
        plt.close(fig)  # release each figure so dozens don't accumulate in memory
# encode the test set into a tensor of latent vectors
z_tensor = None
inf_or_unknown_cardinality = ((test_ds.cardinality() == tf.data.INFINITE_CARDINALITY)
                              or (test_ds.cardinality() == tf.data.UNKNOWN_CARDINALITY)).numpy()
# fall back to a fixed number of batches when the cardinality is not known up front
batches = test_ds.cardinality().numpy() if not inf_or_unknown_cardinality else 500
with tqdm(total=batches) as pbar:
    for batch in test_ds.take(batches):
        z = cae.encoder(batch).numpy()
        if z_tensor is None:
            z_tensor = tf.convert_to_tensor(z)
        else:
            z_tensor = tf.concat([z_tensor, tf.convert_to_tensor(z)], axis=0)
        pbar.update(1)
z_tensor.shape
TensorShape([16096, 1024])
z_np = z_tensor.numpy()
n_z = z_np.shape[0]
n_z_train = int(0.8 * n_z)  # 80/20 train/test split over the latent vectors
z_train = z_np[:n_z_train]
z_test = z_np[n_z_train:]
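This split is sequential; if the TFDS test-split ordering is not random, shuffling the latent vectors before splitting may give a more representative held-out set. A sketch (my addition), reusing the seed defined earlier and hypothetical *_shuffled names:

# sketch (added): shuffle before the 80/20 split
rng = np.random.default_rng(seed)
perm = rng.permutation(n_z)
z_train_shuffled, z_test_shuffled = z_np[perm[:n_z_train]], z_np[perm[n_z_train:]]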
random_state = 1
reg_covar = 0.1
logps = []
k_values = [1, 10, 20, 50, 70, 80, 100, 200]
cov_types = ['diag', 'full']  # 'cov' is not a valid sklearn covariance_type
for k in k_values:
    for cov_type in cov_types:
        try:
            gm_fit = GaussianMixture(n_components=k, covariance_type=cov_type,
                                     random_state=random_state, reg_covar=reg_covar).fit(z_train)
            logp_gm = gm_fit.score(X=z_test)  # mean per-sample held-out log-likelihood
            print(f'For Gaussian Mixture with k = {k} and cov type {cov_type}, logp = {logp_gm}')
            logps.append({'k': k, 'cov_type': cov_type, 'logp': logp_gm})
            print('##############')
        except Exception as e:
            print(f'Caught exception {e}')
For Gaussian Mixture with k = 1 and cov type diag, logp = -411.81840966578113
##############
For Gaussian Mixture with k = 1 and cov type full, logp = 158.8647315390892
##############
For Gaussian Mixture with k = 10 and cov type diag, logp = -303.1187019192367
##############
For Gaussian Mixture with k = 10 and cov type full, logp = 160.84260497329603
##############
For Gaussian Mixture with k = 20 and cov type diag, logp = -281.46386500449336
##############
For Gaussian Mixture with k = 20 and cov type full, logp = 161.0652689730855
##############
For Gaussian Mixture with k = 50 and cov type diag, logp = -253.4661870891009
##############
[KeyboardInterrupt traceback elided: the run was stopped manually inside GaussianMixture.fit during the k = 50, full-covariance fit, so the remaining (k, cov_type) combinations were not scored.]
logps_df = pd.DataFrame.from_records(data=logps)
logps_df.sort_values(by='logp',ascending=False).reset_index()
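Held-out average log-likelihood is one model-selection criterion; BIC is a common alternative that explicitly penalizes the very large full-covariance parameter count. A minimal sketch (not part of the original run), reusing the fitted-model pattern above for a single setting:

# sketch (added): BIC-based comparison for one (k, cov_type) setting; lower is better
gm = GaussianMixture(n_components=20, covariance_type='full',
                     random_state=random_state, reg_covar=reg_covar).fit(z_train)
print(f'BIC = {gm.bic(z_test)}')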